import argparse
import os
import torch
import torch.utils.data
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

from vae_models.conv_model import *
from vae_models.fc_model import *


class VAEBuilder:
    """Assemble the pieces of a VAE experiment from command-line arguments.

    Call :meth:`get_arguments` first: it stores ``self.args`` and
    ``self.device``, which the other methods read. All logs and checkpoints
    for a run live under ``vae_logs/<log_dir>``.
    """

    # Dataset names understood by ``--dataset`` / :meth:`get_dataset`.
    SUPPORTED_DATASETS = ('CelebA', 'MNIST', 'CIFAR', 'SVHN')

    def get_arguments(self, args=None):
        """Parse command-line options and choose the compute device.

        Args:
            args: list of argument strings to parse; ``None`` parses
                ``sys.argv[1:]`` (argparse default).

        Returns:
            ``(args, device)``: the parsed namespace (with an added ``cuda``
            flag) and the ``torch.device`` to use. Both are also stored on
            ``self``. As a side effect ``vae_logs/<log_dir>`` is created.
        """
        sigma_modes = ['scalar_fixed', 'scalar', 'optimal', 'optimal_log', 'posthoc']
        distributions = ['gaussian', 'beta', 'categorical', 'bernoulli']

        parser = argparse.ArgumentParser(description='VAE MNIST Example')
        parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                            help='input batch size for training (default: 128)')
        parser.add_argument('--epochs', type=int, default=10, metavar='N',
                            help='number of epochs to train (default: 10)')
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status')
        parser.add_argument('--learn_beta', type=int, default=1, metavar='N',
                            help='learn beta. boolean.')
        parser.add_argument('--n_filters', type=int, default=32, metavar='N',
                            help='number of filters in the conv net.')
        parser.add_argument('--conv', type=int, default=1, metavar='N')
        parser.add_argument('--z_dim', type=int, default=20, metavar='N')
        parser.add_argument('--sigma', type=float, default=1, metavar='N')
        # Required, so no default: logs/checkpoints go to vae_logs/<log_dir>.
        parser.add_argument('--log_dir', type=str, metavar='N', required=True,
                            help='run name; outputs are written under vae_logs/<log_dir>')
        parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
                            help='one of: ' + ', '.join(self.SUPPORTED_DATASETS))
        # ``help`` must be a string: argparse crashes while formatting
        # ``--help`` output if it is given a list.
        parser.add_argument('--sigma_mode', type=str, default='scalar_fixed', metavar='N',
                            help='one of: ' + ', '.join(sigma_modes))
        parser.add_argument('--distribution', type=str, default='gaussian', metavar='N',
                            help='one of: ' + ', '.join(distributions))
        parser.add_argument('--test_discretized', type=int, default=0, metavar='N')
        parser.add_argument('--detach_sigma_network', type=int, default=0, metavar='N')
        parser.add_argument('--lr', type=float, default=1e-3, metavar='N')

        args = parser.parse_args(args)
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        os.makedirs('vae_logs/{}'.format(args.log_dir), exist_ok=True)

        device = torch.device("cuda" if args.cuda else "cpu")

        self.args = args
        self.device = device

        return args, device

    def get_dataset(self):
        """Build train/test DataLoaders for ``self.args.dataset``.

        Images are resized to 28x28 and converted to tensors; data is stored
        in (and downloaded to) ``../../data``.

        Returns:
            ``(train_loader, test_loader)``.

        Raises:
            ValueError: if ``args.dataset`` is not in ``SUPPORTED_DATASETS``.
        """
        args = self.args
        # Fail fast with a clear message instead of an UnboundLocalError when
        # neither dataset variable gets assigned below.
        if args.dataset not in self.SUPPORTED_DATASETS:
            raise ValueError('Unknown dataset {!r}; expected one of: {}'.format(
                args.dataset, ', '.join(self.SUPPORTED_DATASETS)))

        root = '../../data'
        transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
        if args.dataset == 'CelebA':
            train_dataset = datasets.CelebA(root, split='train', download=True, transform=transform)
            test_dataset = datasets.CelebA(root, split='test', download=True, transform=transform)
        elif args.dataset == 'MNIST':
            train_dataset = datasets.MNIST(root, train=True, download=True, transform=transform)
            test_dataset = datasets.MNIST(root, train=False, download=True, transform=transform)
        elif args.dataset == 'CIFAR':
            train_dataset = datasets.CIFAR10(root, train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR10(root, train=False, download=True, transform=transform)
        else:  # 'SVHN'
            train_dataset = datasets.SVHN(root, split='train', download=True, transform=transform)
            # The test split is a separate file from the train split, so it
            # must be requested (and downloaded) explicitly.
            test_dataset = datasets.SVHN(root, split='test', download=True, transform=transform)

        kwargs = {'num_workers': 10, 'pin_memory': True} if args.cuda else {}
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
        # NOTE(review): the test loader is shuffled too; confirm this is intended
        # (e.g. for random sample visualisation) rather than an oversight.
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

        return train_loader, test_loader

    def get_summary_writer(self, purge_step=0):
        """Return a TensorBoard ``SummaryWriter`` logging to ``vae_logs/<log_dir>``.

        Args:
            purge_step: forwarded to ``SummaryWriter``; events at or after this
                step from a previous run are discarded on resume.
        """
        args = self.args
        return SummaryWriter(log_dir='vae_logs/' + args.log_dir, purge_step=purge_step)

    def build_vae(self, device, args):
        """Instantiate the VAE selected by ``args`` and move it to *device*.

        ``args.conv == 1`` selects ``VAE_Conv``, otherwise ``VAE`` (both are
        expected to come from the ``vae_models`` star imports). MNIST is
        treated as single-channel; every other dataset as 3-channel.
        """
        base_class = VAE_Conv if args.conv == 1 else VAE
        channels = 1 if args.dataset == 'MNIST' else 3
        return base_class(device, channels, args).to(device)

    def load_initial_checkpoint(self, model):
        """Load ``vae_logs/<log_dir>/checkpoint_<epochs>.pt`` into *model* if present.

        Does nothing when the file does not exist. Tensors are mapped onto
        ``self.device`` so that a checkpoint saved on a GPU can be restored
        on a CPU-only machine.
        """
        args = self.args
        checkpoint = 'vae_logs/{}/checkpoint_{}.pt'.format(args.log_dir, str(args.epochs))
        if os.path.exists(checkpoint):
            print('Loading checkpoint from {}'.format(checkpoint))
            # getattr: tolerate callers that set ``self.args`` by hand without
            # calling get_arguments (None = torch's default placement).
            state = torch.load(checkpoint, map_location=getattr(self, 'device', None))
            model.load_state_dict(state)
